

Multi-head Transformers Provably Learn Symbolic Multi-step Reasoning via Gradient Descent

Neural Information Processing Systems

Transformers have demonstrated remarkable capabilities in multi-step reasoning tasks. However, our understanding of the mechanisms by which they acquire these abilities through training remains limited, particularly from a theoretical standpoint. This work investigates how transformers learn to solve symbolic multi-step reasoning problems through chain-of-thought processes, focusing on path-finding in trees. We analyze two intertwined tasks: a backward reasoning task, where the model outputs a path from a goal node to the root, and a more complex forward reasoning task, where the model implements two-stage reasoning by first identifying the goal-to-root path and then reversing it to produce the root-to-goal path. Our theoretical analysis, grounded in the dynamics of gradient descent, shows that trained one-layer transformers can provably solve both tasks with generalization guarantees to unseen trees. In particular, our multi-phase training dynamics for forward reasoning elucidate how different attention heads learn to specialize and coordinate autonomously to solve the two subtasks in a single autoregressive path. These results provide a mechanistic explanation of how trained transformers can implement sequential algorithmic procedures. Moreover, they offer insights into the emergence of reasoning abilities, suggesting that when tasks are structured to take intermediate chain-of-thought steps, even shallow multi-head transformers can effectively solve problems that would otherwise require deeper architectures.
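To make the two tasks concrete, the sketch below computes their target outputs directly, assuming the tree is given as a parent map (a hypothetical encoding; the paper's actual token format is not described here). It shows what a correct chain of thought should produce, not how the one-layer transformer computes it: the backward task follows parent pointers from the goal up to the root, and the forward task reverses that path.

```python
from typing import Dict, List


def backward_path(parent: Dict[int, int], goal: int, root: int) -> List[int]:
    """Goal-to-root path: repeatedly step from a node to its parent."""
    path = [goal]
    node = goal
    while node != root:
        node = parent[node]  # one reasoning step: move to the parent
        path.append(node)
    return path


def forward_path(parent: Dict[int, int], goal: int, root: int) -> List[int]:
    """Root-to-goal path: run the backward stage first, then reverse it."""
    return list(reversed(backward_path(parent, goal, root)))


# Hypothetical tree rooted at 0 with edges 0-1, 0-2, 1-3, 3-4 (child -> parent).
parent = {1: 0, 2: 0, 3: 1, 4: 3}
print(backward_path(parent, 4, 0))  # [4, 3, 1, 0]
print(forward_path(parent, 4, 0))   # [0, 1, 3, 4]
```

The forward task is written as "backward, then reverse" to mirror the two-stage decomposition described in the abstract.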


Value Improved Actor Critic Algorithms

Neural Information Processing Systems

To learn approximately optimal acting policies for decision problems, modern Actor Critic algorithms rely on deep Neural Networks (DNNs) to parameterize the acting policy and on greedification operators to iteratively improve it. The reliance on DNNs implies a gradient-based improvement, which is, per step, much less greedy than the improvement achievable by greedier operators such as the greedy update used by Q-learning algorithms. On the other hand, slow changes to the policy can also benefit the stability of the learning process, resulting in a tradeoff between greedification and stability. To better address this tradeoff, we propose to decouple the acting policy from the policy evaluated by the critic. This allows the agent to separately improve the critic's policy (e.g.
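The per-step greediness gap can be illustrated in a single state. The sketch below, with hypothetical Q-values and a hypothetical step size, contrasts the greedy update used by Q-learning, which moves all probability to the best action at once, with one gradient step on a softmax policy. It also shows, as an assumption about the general idea rather than the authors' algorithm, how a critic might evaluate a greedier policy than the one used for acting.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max()  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()


q = np.array([1.0, 2.0, 0.5])  # hypothetical critic estimates Q(s, a)

# Greedy update (as in Q-learning): all mass on argmax_a Q(s, a) in one step.
greedy = np.zeros_like(q)
greedy[np.argmax(q)] = 1.0

# One gradient-ascent step on J = sum_a pi(a) Q(s, a) for a softmax policy,
# starting from uniform logits. dJ/dlogit_b = pi_b * (Q_b - pi . Q).
logits = np.zeros_like(q)
pi = softmax(logits)
grad = pi * (q - pi @ q)
step_size = 0.1  # hypothetical learning rate
pi_new = softmax(logits + step_size * grad)

# Decoupling sketch: the critic may bootstrap from a greedier policy
# than the slowly updated policy used for acting.
v_critic = greedy @ q   # value of the greedier evaluated policy
v_acting = pi_new @ q   # value of the acting policy after one gradient step

print(greedy)           # [0. 1. 0.]
print(pi_new.round(3))
```

After one gradient step the best action's probability rises only slightly above 1/3, whereas the greedy update assigns it probability 1.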


Convexity in Disguise: A Theoretical Framework for Nonconvex Low-Rank Matrix Estimation

arXiv.org Machine Learning

Nonconvex methods have emerged as a dominant approach for low-rank matrix estimation, a problem that arises widely in machine learning and AI for learning and representing high-dimensional data. Existing analyses for these methods often require additional regularization to mitigate nonconvexity, even though such regularization is often unnecessary in practice. Moreover, most analyses rely on problem-specific arguments that are difficult to generalize to more complex settings. In this paper, we develop a theoretical framework for studying nonconvex procedures across a broad class of low-rank matrix estimation problems. Rather than focusing on a specific model, we reveal a fundamental mechanism that explains why nonconvex procedures can behave well in low-rank estimation. Our key device is a benign regularizer that does not alter the original update rule, but yields an equivalent locally strongly convex formulation of the algorithm. This perspective uncovers a disguised convexity inherent in the nonconvex procedure and provides a new route to theoretical guarantees for nonconvex low-rank matrix estimation.
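As an illustration of the kind of unregularized nonconvex procedure such a framework targets, the sketch below runs plain factorized gradient descent on the symmetric problem $\min_U \tfrac14\|UU^\top - M\|_F^2$ for a synthetic rank-2 PSD matrix. The problem instance, step size, and iteration count are assumptions chosen for illustration. Since the abstract states that the benign regularizer leaves the update rule unchanged, no regularizer appears in the iteration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, r = 20, 2

# Synthetic ground truth M = X X^T of rank r with eigenvalues 9 and 4.
Q, _ = np.linalg.qr(rng.standard_normal((n, r)))
X = Q * np.array([3.0, 2.0])  # scale the orthonormal columns
M = X @ X.T

# Vanilla factorized gradient descent on f(U) = 1/4 * ||U U^T - M||_F^2,
# with no explicit regularization and a small random initialization.
U = 1e-3 * rng.standard_normal((n, r))
eta = 0.1 / np.linalg.norm(M, 2)  # step size scaled by the spectral norm of M
for _ in range(5000):
    grad = (U @ U.T - M) @ U      # gradient of f with respect to U
    U = U - eta * grad

# U is only identifiable up to rotation, so compare U U^T with M.
rel_err = np.linalg.norm(U @ U.T - M) / np.linalg.norm(M)
print(f"relative error: {rel_err:.2e}")
```

The error is measured on $UU^\top$ rather than on $U$ because any $UR$ with orthogonal $R$ fits $M$ equally well.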



Fitting trees to ℓ1-hyperbolic distances

Neural Information Processing Systems

Building trees to represent or to fit distances is a critical component of phylogenetic analysis, metric embeddings, approximation algorithms, geometric graph neural nets, and the analysis of hierarchical data. Much of the previous algorithmic work, however, has focused on generic metric spaces (i.e., those with no a priori constraints). Leveraging several ideas from the mathematical analysis of hyperbolic geometry and geometric group theory, we study the tree fitting problem as finding the relation between the hyperbolicity (ultrametricity) vector and the error of tree (ultrametric) embedding. That is, we define a vector of hyperbolicity (ultrametric) values over all triples of points and compare the ℓp norms of this vector with the ℓq norm of the distortion of the best tree fit to the distances. This formulation allows us to define the average hyperbolicity (ultrametricity) in terms of a normalized ℓ1 norm of the hyperbolicity vector. Furthermore, we can interpret the classical tree fitting result of Gromov as a p = q = ∞ result. We present an algorithm HCCROOTEDTREEFIT such that the ℓ1 error of the output embedding is analytically bounded in terms of the ℓ1 norm of the hyperbolicity vector (i.e., p = q = 1), and this bound is tight. Moreover, this algorithm has significantly different theoretical and empirical performance compared to Gromov's result and related algorithms.
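The hyperbolicity and ultrametricity vectors can be sketched as follows, assuming the usual Gromov-product form of the four-point condition with a fixed base point $w$, and normalizing the ℓ1 norm by the number of triples (both are assumptions; the paper's exact definitions may differ). For a triple whose Gromov products are sorted as $a \le b \le c$, the smallest $\delta$ satisfying $a \ge \min(b, c) - \delta$ is $b - a$; similarly, the ultrametric defect of a triple is the gap between its two largest distances.

```python
import itertools

import numpy as np


def gromov_product(D: np.ndarray, x: int, y: int, w: int) -> float:
    """(x|y)_w = (d(x, w) + d(y, w) - d(x, y)) / 2."""
    return 0.5 * (D[x, w] + D[y, w] - D[x, y])


def hyperbolicity_vector(D: np.ndarray, w: int) -> np.ndarray:
    """Per-triple four-point defect relative to base point w (sketch)."""
    vals = []
    for x, y, z in itertools.combinations(range(D.shape[0]), 3):
        a, b, _ = sorted([gromov_product(D, x, y, w),
                          gromov_product(D, y, z, w),
                          gromov_product(D, x, z, w)])
        vals.append(b - a)  # smallest delta with a >= min(b, c) - delta
    return np.array(vals)


def ultrametricity_vector(D: np.ndarray) -> np.ndarray:
    """Per-triple violation of d(x, z) <= max(d(x, y), d(y, z))."""
    vals = []
    for x, y, z in itertools.combinations(range(D.shape[0]), 3):
        _, q, s = sorted([D[x, y], D[y, z], D[x, z]])
        vals.append(s - q)  # gap between the two largest distances
    return np.array(vals)


# Shortest-path metric of the 4-cycle 0-1-2-3-0.
D_c4 = np.array([[0, 1, 2, 1],
                 [1, 0, 1, 2],
                 [2, 1, 0, 1],
                 [1, 2, 1, 0]], dtype=float)
h = hyperbolicity_vector(D_c4, w=0)
print(h, np.abs(h).mean())  # normalized l1 norm = average hyperbolicity
```

On the 4-cycle with base point 0, only the triple (1, 2, 3) has a nonzero defect, so the average is 0.25; on a path metric, which is a tree metric, every entry is 0.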



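is as powerful as CWL with the generalised update rule $\mathrm{HASH}\big(c^t_\sigma, c^t_{\mathcal{B}}(\sigma), c^t_{\mathcal{C}}(\sigma), c^t_{\downarrow}(\sigma), c^t_{\uparrow}(\sigma)\big)$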

Neural Information Processing Systems

A.1 Cellular WL Results

In this section, we assume basic familiarity with the WL test and its higher-order variants. For an introduction to these topics, we refer the reader to the survey of Sato [62]. We begin by introducing a few useful concepts. A cellular colouring is a map $c$ that maps a cell complex $X$ and one of its cells $\sigma$ to a colour $c^X_\sigma$ from a fixed colour palette. Let $X, Y$ be two regular cell complexes and $c$ a cellular colouring. We say that $X, Y$ are $c$-similar, denoted by $c^X = c^Y$, if the number of cells in $X$ coloured with a given colour equals the number of cells in $Y$ with the same colour. Otherwise, we have $c^X \neq c^Y$. We emphasise that in this paper we are interested only in colourings $c$ with the property that any two isomorphic cell complexes are $c$-similar. A cellular colouring $c$ refines a cellular colouring $d$, denoted by $c \sqsubseteq d$, if for all cell complexes $X$ and $Y$ and all $\sigma \in P_X$ and $\tau \in P_Y$, $c^X_\sigma = c^Y_\tau$ implies $d^X_\sigma = d^Y_\tau$. If, additionally, $d \sqsubseteq c$, we say the two colourings are equivalent and write $c \equiv d$. We state the following result from Bodnar et al. [8] about simplicial colourings, translated here directly to cell complexes. The proof carries over unchanged, so we refer the reader to their work for it. Let $X, Y$ be any regular cell complexes with $A \subseteq P_X$ and $B \subseteq P_Y$. Consider two cellular colourings $c, d$ such that $c \sqsubseteq d$.
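The definitions of $c$-similarity and refinement can be checked mechanically on finite examples. In the sketch below, a colouring of one complex is a hypothetical dictionary from cell names to colours: $c$-similarity compares colour histograms, and the refinement test checks that equal $c$-colours always force equal $d$-colours across all supplied cells. Because $c \sqsubseteq d$ quantifies over all cell complexes, this function can only confirm refinement on the complexes it is given.

```python
from collections import Counter
from typing import Dict, Hashable, List

# Hypothetical representation: colouring of one complex as cell -> colour.
Colouring = Dict[Hashable, Hashable]


def c_similar(cX: Colouring, cY: Colouring) -> bool:
    """X, Y are c-similar iff every colour is used on equally many cells."""
    return Counter(cX.values()) == Counter(cY.values())


def refines(c_cols: List[Colouring], d_cols: List[Colouring]) -> bool:
    """Check c refines d on the given complexes: the map from c-colours to
    d-colours, collected over all cells of all complexes, must be a function."""
    c_to_d: Dict[Hashable, Hashable] = {}
    for c_by_cell, d_by_cell in zip(c_cols, d_cols):
        for cell, c_col in c_by_cell.items():
            d_col = d_by_cell[cell]
            if c_to_d.setdefault(c_col, d_col) != d_col:
                return False  # same c-colour, different d-colours
    return True


# Two single-edge complexes: vertices coloured "a", edges coloured "b".
cX = {"v1": "a", "v2": "a", "e12": "b"}
cY = {"u1": "a", "u2": "a", "f12": "b"}
# A coarser colouring d that gives every cell the same colour.
dX = {"v1": 0, "v2": 0, "e12": 0}
dY = {"u1": 0, "u2": 0, "f12": 0}

print(c_similar(cX, cY))                # True
print(refines([cX, cY], [dX, dY]))      # True: c refines d
print(refines([dX, dY], [cX, cY]))      # False: d does not refine c
```

Equivalence $c \equiv d$ then corresponds to `refines` holding in both directions.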


Assumptions and Likelihoods in More Detail

Neural Information Processing Systems

A.1 Notation

Let $T$ be a failure time with CDF $F$. The survival function of $T$ is $\bar F = 1 - F$. We denote failure models by $F_{\theta_T}$. Let $C$ be a censoring time with CDF $G$, survival function $\bar G$, and model $G_{\theta_C}$. Under right-censoring, define $U = \min(T, C)$ and $\Delta = \mathbb{1}[T \le C]$; we observe $(X_i, U_i, \Delta_i)$. We use G(t) to denote P(C t).
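As a minimal illustration of this notation, the sketch below simulates right-censored data and fits an exponential failure model $F_{\theta_T}$ with rate $\lambda$. It assumes independent censoring, so the factors involving $G$ do not depend on $\lambda$ and can be dropped; each uncensored observation then contributes $\log f(U_i)$ and each censored one contributes $\log \bar F(U_i)$. The rates, sample size, and exponential model are hypothetical choices, and this is not necessarily the likelihood used by the authors.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000
lam_T, lam_C = 0.5, 0.2  # hypothetical failure and censoring rates

T = rng.exponential(scale=1 / lam_T, size=n)  # failure times
C = rng.exponential(scale=1 / lam_C, size=n)  # censoring times
U = np.minimum(T, C)                          # observed times U = min(T, C)
Delta = (T <= C).astype(float)                # event indicator 1[T <= C]


def exp_loglik(lam: float, U: np.ndarray, Delta: np.ndarray) -> float:
    # Events: log f(u) = log(lam) - lam * u.
    # Censored: log Fbar(u) = -lam * u.
    return float(np.sum(Delta * np.log(lam) - lam * U))


# Closed-form maximiser of exp_loglik: (number of events) / (total time at risk).
lam_hat = Delta.sum() / U.sum()
print(f"lam_hat = {lam_hat:.3f}")
```

For the exponential model the maximiser has the closed form $\hat\lambda = \sum_i \Delta_i / \sum_i U_i$, so `lam_hat` should land close to the simulated rate 0.5.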